-- 修改自《OpenResty 最佳实践》中提供的 redis 封装代码
--      文中给出的原始代码链接已无法打开,因此完整复制过来并加以修改
-- Author: chleniang@163.com
--
-- 使用方法:
--      local redis = require("lib.cls_redis")
--      local red = redis:new({
--          host = "127.0.0.1",  -- redis主机 默认"127.0.0.1"
--          port = 6379,     -- redis端口号 默认6379
--          password = "",   -- redis密码 默认空字符串(无密码)
--          db_index = 0,    -- redis使用的库编号 默认0号库
--          timeout = 1000,  -- redis超时时间 单位(ms) 默认1000
--                           -- 此处为方便这个超时时间是 settimeouts(a,b,c) 三个参数的统一值
--          keepalive_time = 60000,  -- 连接池空闲时间 单位(ms) 默认60秒
--          keepalive_size = 100,    -- 连接池大小 默认100
--      })
--
--      local ok, err = red:set('dog', 'dog wang~wang~')
--      if not ok then
--          util.say_line("failed to set dog: ", tostring(err) )
--          return
--      end
--
--      new 完之后直接用,不用关心关闭等操作,连接池有效直接用连接池中的连接
--      支持 pipeline subscribe
--

local redis_c = require("resty.redis")

-- Prefer LuaJIT's table.new (pre-sized tables) when it can be loaded;
-- otherwise fall back to a stub that ignores the size hints.
local has_table_new, new_tab = pcall(require, "table.new")
if not (has_table_new and type(new_tab) == "function") then
    new_tab = function() return {} end
end


-- Module table; also serves as the method table for instances (see `mt`).
-- The 155 passed to new_tab is only a hash-part sizing hint (the fallback
-- new_tab ignores it); presumably a rough count of the generated command
-- wrappers plus module functions — TODO confirm.
local _M = new_tab(0, 155)
_M._VERSION = "0.01"


-- Redis commands exposed as methods on this module. The registration loop
-- below turns each name into _M.<name>(self, ...), which forwards to
-- do_command: one request on a pooled connection, or queued when the
-- instance is in pipeline mode.
--
-- "subscribe", "unsubscribe" and "punsubscribe" are deliberately commented
-- out: subscriptions need a long-lived dedicated connection, which is what
-- _M.subscribe provides.
-- NOTE(review): "psubscribe" (and "monitor") are still listed; through
-- do_command the connection object is discarded after the first reply, so
-- further pushed messages cannot be read — these are probably not usable via
-- this wrapper. Confirm before relying on them.
local commands = {
    "append", "auth", "bgrewriteaof",
    "bgsave", "bitcount", "bitop",
    "blpop", "brpop",
    "brpoplpush", "client", "config",
    "dbsize",
    "debug", "decr", "decrby",
    "del", "discard", "dump",
    "echo",
    "eval", "exec", "exists",
    "expire", "expireat", "flushall",
    "flushdb", "get", "getbit",
    "getrange", "getset", "hdel",
    "hexists", "hget", "hgetall",
    "hincrby", "hincrbyfloat", "hkeys",
    "hlen",
    "hmget", "hmset", "hscan",
    "hset",
    "hsetnx", "hvals", "incr",
    "incrby", "incrbyfloat", "info",
    "keys",
    "lastsave", "lindex", "linsert",
    "llen", "lpop", "lpush",
    "lpushx", "lrange", "lrem",
    "lset", "ltrim", "mget",
    "migrate",
    "monitor", "move", "mset",
    "msetnx", "multi", "object",
    "persist", "pexpire", "pexpireat",
    "ping", "psetex", "psubscribe",
    "pttl",
    "publish", --[[ "punsubscribe", ]] "pubsub",
    "quit",
    "randomkey", "rename", "renamenx",
    "restore",
    "rpop", "rpoplpush", "rpush",
    "rpushx", "sadd", "save",
    "scan", "scard", "script",
    "sdiff", "sdiffstore",
    "select", "set", "setbit",
    "setex", "setnx", "setrange",
    "shutdown", "sinter", "sinterstore",
    "sismember", "slaveof", "slowlog",
    "smembers", "smove", "sort",
    "spop", "srandmember", "srem",
    "sscan",
    "strlen", --[[ "subscribe",  ]] "sunion",
    "sunionstore", "sync", "time",
    "ttl",
    "type", --[[ "unsubscribe", ]] "unwatch",
    "watch", "zadd", "zcard",
    "zcount", "zincrby", "zinterstore",
    "zrange", "zrangebyscore", "zrank",
    "zrem", "zremrangebyrank", "zremrangebyscore",
    "zrevrange", "zrevrangebyscore", "zrevrank",
    "zscan",
    "zscore", "zunionstore", "evalsha"
}


-- Metatable for instances created by _M.new: field lookups that miss on the
-- instance (all methods) fall through to the module table.
local mt = { __index = _M }


--- Decide whether a reply carries no data.
-- A table counts as null when every value in it is ngx.null (an empty table
-- therefore counts as null too); a scalar counts as null when it is nil or
-- ngx.null. Anything else (including false) is not null.
-- @param res reply value from resty.redis
-- @return boolean
local function is_redis_null(res)
    if type(res) == "table" then
        for _, item in pairs(res) do
            if item ~= ngx.null then
                return false
            end
        end
        return true
    end
    return res == nil or res == ngx.null
end

--- Connect a fresh resty.redis object using this instance's settings.
-- Applies `self.timeout` to connect, send and read. Runs AUTH only when the
-- connection is brand new (reused-times == 0); a connection taken from the
-- pool was authenticated when first created. Runs SELECT when
-- `self.db_index` > 0.
-- @param redis resty.redis object (from redis_c:new())
-- @return the truthy connect() result on success; nil, err on failure.
--         The error is nil rather than a truthy `{}`, so `if not ok` checks
--         work. A connection that was established but failed AUTH or
--         SELECT is closed rather than left dangling.
function _M.connect_mod(self, redis)
    redis:set_timeouts(self.timeout, self.timeout, self.timeout)

    local conn_res, conn_err = redis:connect(self.host, self.port)
    if not conn_res then
        return nil, conn_err
    end

    if self.password ~= "" then
        local count, reused_err = redis:get_reused_times()
        if 0 == count then
            local res, auth_err = redis:auth(self.password)
            if not res then
                redis:close()
                -- tostring: a nil error must not turn into a concat error
                return nil, "redis authenticate failed: " .. tostring(auth_err)
            end
        elseif reused_err then
            redis:close()
            return nil, "failed to get reused times: " .. tostring(reused_err)
        end
    end

    if self.db_index > 0 then
        local ok, err = redis:select(self.db_index)
        if not ok then
            redis:close()
            return nil, "redis SELECT DB failed: " .. tostring(err)
        end
    end

    return conn_res, conn_err
end

--- Return a connection to the cosocket keepalive pool.
-- Called with '.', e.g. `self.set_keepalive_mod(redis, self)`.
-- connect_mod only issues SELECT when db_index > 0, so a connection that was
-- switched to another database is reset to db 0 before it is pooled.
-- @param redis connected resty.redis object
-- @param opts  optional table with db_index, keepalive_time and
--              keepalive_size (normally the wrapper instance). When omitted,
--              the module table is used. Its fields are nil, so this falls
--              back to the legacy behaviour: always SELECT 0, and pass nil
--              pool parameters to set_keepalive (nginx-level defaults).
-- @return result of redis:set_keepalive(), or nil, err if SELECT 0 failed
function _M.set_keepalive_mod(redis, opts)
    opts = opts or _M

    if opts.db_index ~= 0 then
        local ok, err = redis:select(0)
        if not ok then
            -- never pool a connection whose selected db is uncertain
            redis:close()
            return nil, err
        end
    end

    return redis:set_keepalive(opts.keepalive_time, opts.keepalive_size)
end

--- Enter pipeline mode. Until commit_pipeline() is called, every command
-- method called on this instance is queued in `self._reqs` instead of being sent.
function _M.init_pipeline(self)
    local queue = {}
    self._reqs = queue
end

--- Send all queued commands in one pipeline and leave pipeline mode.
-- @return array of per-command replies on success. Entries that are null
--         (per is_redis_null) are set to nil, so the array may have holes.
--         Returns {}, err when nothing was queued, connecting failed or the
--         pipeline failed; nil, err when the redis object cannot be created.
function _M.commit_pipeline(self)
    local reqs = self._reqs
    -- leave pipeline mode unconditionally, so an empty or failed pipeline
    -- does not keep queueing every later command
    self._reqs = nil

    if nil == reqs or 0 == #reqs then
        return {}, "no pipeline"
    end

    local redis, err = redis_c:new()
    if not redis then
        return nil, err
    end

    -- check err as well: connect failures must not be ignored
    local ok, conn_err = self:connect_mod(redis)
    if not ok or conn_err then
        return {}, conn_err
    end

    redis:init_pipeline()
    for _, vals in ipairs(reqs) do
        -- vals = { cmd, arg1, arg2, ... }
        local fun = redis[vals[1]]
        fun(redis, unpack(vals, 2))
    end

    local results, commit_err = redis:commit_pipeline()
    if not results or commit_err then
        -- connection state is unknown after a failed pipeline: don't pool it
        redis:close()
        return {}, commit_err
    end

    if is_redis_null(results) then
        results = {}
        ngx.log(ngx.WARN, "is null")
    end

    self.set_keepalive_mod(redis, self)

    -- bound computed once, before any holes are introduced
    for i = 1, #results do
        if is_redis_null(results[i]) then
            results[i] = nil
        end
    end

    return results, commit_err
end

---
-- Subscribe to `channel` on a dedicated connection.
--
-- Returns a reader closure `read(do_read)`:
--   * read() or read(true): blocks in redis:read_reply() and returns the next
--     pushed reply, or nil, err (e.g. a read timeout; the caller may retry).
--   * read(<anything else>): unsubscribes from `channel` and releases the
--     connection. Returns nothing on success, nil, err on failure.
--
-- Why releasing drains: after UNSUBSCRIBE, Redis may already have pushed
-- messages that were not read yet. The reply to redis:unsubscribe() can
-- therefore be a {"message", ...} entry instead of the confirmation
-- {"unsubscribe", channel, remaining_count}. Pooling such a connection would
-- hand stale data to its next user. The release path therefore keeps
-- reading until the confirmation arrives, and closes the connection instead
-- of pooling it if that fails.
--
-- Usage:
--      local read, err = red:subscribe("ch")
--      if not read then ngx.log(ngx.ERR, err) return end
--      local msg, err = read()   -- e.g. {"message", "ch", payload}
--      read(false)               -- unsubscribe and release the connection
--
-- @param channel channel name
-- @return reader function, or nil, err
function _M.subscribe(self, channel)
    local redis, err = redis_c:new()
    if not redis then
        return nil, err
    end

    local ok, conn_err = self:connect_mod(redis)
    if not ok or conn_err then
        return nil, conn_err
    end

    local res, sub_err = redis:subscribe(channel)
    if not res then
        redis:close()
        return nil, sub_err
    end

    -- Unsubscribe, drain already-pushed messages, then pool the connection.
    local function release()
        local reply, unsub_err = redis:unsubscribe(channel)
        while reply and (type(reply) ~= "table" or reply[1] ~= "unsubscribe") do
            reply, unsub_err = redis:read_reply()
        end
        if not reply then
            redis:close()
            return nil, unsub_err
        end
        self.set_keepalive_mod(redis, self)
    end

    local function do_read_func(do_read)
        if do_read == nil or do_read == true then
            local msg, read_err = redis:read_reply()
            if not msg then
                return nil, read_err
            end
            return msg
        end
        return release()
    end

    return do_read_func
end

--- Execute one Redis command for a wrapper instance.
-- In pipeline mode (self._reqs set), the command and its arguments are
-- queued and nothing is returned. Otherwise a connection is acquired, the
-- command is run, and the connection is returned to the pool.
-- @param cmd name of a resty.redis command method
-- @return the reply, with null replies (per is_redis_null) mapped to nil;
--         or nil, err on failure
local function do_command(self, cmd, ...)
    if self._reqs then
        table.insert(self._reqs, { cmd, ... })
        return
    end

    local redis, err = redis_c:new()
    if not redis then
        return nil, err
    end

    local ok, conn_err = self:connect_mod(redis)
    if not ok or conn_err then
        return nil, conn_err
    end

    local result, cmd_err = redis[cmd](redis, ...)
    if not result or cmd_err then
        if result == false then
            -- resty.redis returns false for a Redis error reply: the reply
            -- was fully read, so the connection is still usable
            self.set_keepalive_mod(redis, self)
        else
            -- transport-level failure: connection state unknown
            redis:close()
        end
        return nil, cmd_err
    end

    if is_redis_null(result) then
        result = nil
    end

    self.set_keepalive_mod(redis, self)

    return result, cmd_err
end


-- Generate one method per command name; each closure captures its own
-- `name`, so _M.get(self, ...) becomes do_command(self, "get", ...).
for _, name in ipairs(commands) do
    _M[name] = function(self, ...)
        return do_command(self, name, ...)
    end
end


--- Create a wrapper instance holding connection settings.
-- No connection is opened here; each command acquires one on demand via
-- connect_mod and releases it via set_keepalive_mod.
-- @param opts optional table:
--   host ("127.0.0.1"), port (6379), password ("" = no AUTH),
--   db_index (0), timeout (1000 ms, used for connect/send/read),
--   keepalive_time (60000 ms), keepalive_size (100)
-- @return instance whose methods are resolved through `mt`
function _M.new(self, opts)
    local o = opts or {}
    local instance = {
        host           = o.host or "127.0.0.1",
        port           = o.port or 6379,
        password       = o.password or "",
        db_index       = o.db_index or 0,
        timeout        = o.timeout or 1000,
        keepalive_time = o.keepalive_time or 60000,
        keepalive_size = o.keepalive_size or 100,
    }
    return setmetatable(instance, mt)
end

-- Export the module table (constructor, command methods, helpers).
return _M
